package com.yeskery.nut.extend.jdbc;

import com.yeskery.nut.util.JdbcUtils;

import javax.sql.DataSource;
import java.sql.*;
import java.util.Spliterator;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

/**
 * Spliterator that walks the rows of a SQL query page by page.
 * <p>
 * On construction it runs a row-count query (see {@link #getTotalCountSql(String)}) to learn the total
 * number of rows, obtains a connection via {@code jdbcTemplate.execute(ConnectionCallback)}, and opens the
 * first page (see {@link #getPageOffsetSql(String, long, long)}). Later pages are queried lazily on the same
 * connection once the current page's result set is exhausted.
 * <p>
 * When the last row has been consumed (or the count is zero) the result set and statements are closed, and
 * the connection is closed too unless the data source is a {@code WrapCacheDataSource}. The stream returned
 * by {@link #stream()} performs a full cleanup on close, so callers should close that stream
 * (e.g. try-with-resources).
 * <p>
 * Instances hold mutable JDBC state without synchronization and do not support splitting; use from a
 * single thread.
 * <p>
 * Note: the constructors call the overridable {@link #getTotalCountSql(String)} and
 * {@link #getPageOffsetSql(String, long, long)}, so overrides must not depend on subclass instance state.
 *
 * @author YESKERY
 * 2024/8/29
 *
 * @param <T> result element type
 */
public class PageSqlResultSetSpliterator<T> implements Spliterator<T> {

    /** JDBC operation template */
    private final JdbcTemplate jdbcTemplate;

    /** base sql (without paging clause) */
    private final String sql;

    /** rows per page */
    private final int size;

    /** total row count, as reported by the count query at construction time */
    private final long total;

    /** total page count derived from {@link #total} and {@link #size} */
    private final long pages;

    /** current page number (1-based) */
    private final AtomicLong currentPage;

    /** row-mapping callback */
    private ResultSetRowCallback<T> callback;

    /** sql parameter values; {@code null} means a plain Statement is used */
    private Object[] params;

    /** connection used for all page queries */
    private Connection connection;

    /** Statement used when {@link #params} is null */
    private Statement statement;

    /** PreparedStatement used when {@link #params} is non-null */
    private PreparedStatement preparedStatement;

    /** result set of the current page */
    private ResultSet resultSet;

    /** set once all rows are consumed or resources released; later tryAdvance calls return false */
    private boolean finished;

    /**
     * Builds a paging SQL result-set spliterator.
     * @param jdbcTemplate JDBC operation template
     * @param sql sql
     * @param size rows per page; must be positive
     * @param callback row-mapping callback
     * @throws IllegalArgumentException if {@code size} is not positive
     */
    public PageSqlResultSetSpliterator(JdbcTemplate jdbcTemplate, String sql, int size, ResultSetRowCallback<T> callback) {
        // validate before issuing any query: size == 0 would otherwise divide by zero below
        requirePositiveSize(size);
        this.jdbcTemplate = jdbcTemplate;
        this.sql = sql;
        this.size = size;
        this.callback = callback;
        this.total = jdbcTemplate.query(getTotalCountSql(sql), Long.class);
        this.pages = pageCount(total, size);
        this.currentPage = new AtomicLong(1);
        acquireConnection();
        this.resultSet = getResultSet();
    }

    /**
     * Builds a paging SQL result-set spliterator.
     * @param jdbcTemplate JDBC operation template
     * @param sql sql
     * @param params parameter values bound positionally to the sql placeholders
     * @param size rows per page; must be positive
     * @param callback row-mapping callback
     * @throws IllegalArgumentException if {@code size} is not positive
     */
    public PageSqlResultSetSpliterator(JdbcTemplate jdbcTemplate, String sql, Object[] params, int size, ResultSetRowCallback<T> callback) {
        // validate before issuing any query: size == 0 would otherwise divide by zero below
        requirePositiveSize(size);
        this.jdbcTemplate = jdbcTemplate;
        this.sql = sql;
        this.params = params;
        this.size = size;
        this.callback = callback;
        this.total = jdbcTemplate.query(getTotalCountSql(sql), params, Long.class);
        this.pages = pageCount(total, size);
        this.currentPage = new AtomicLong(1);
        acquireConnection();
        this.resultSet = getResultSet();
    }

    /**
     * Builds a paging SQL result-set spliterator that maps the first column of each row to {@code clazz}.
     * @param jdbcTemplate JDBC operation template
     * @param sql sql
     * @param size rows per page; must be positive
     * @param clazz result type
     */
    @SuppressWarnings("unchecked")
    public PageSqlResultSetSpliterator(JdbcTemplate jdbcTemplate, String sql, int size, Class<T> clazz) {
        this(jdbcTemplate, sql, size, rs -> (T) JdbcUtils.getResultSetValue(rs, 1, clazz));
    }

    /**
     * Builds a paging SQL result-set spliterator that maps the first column of each row to {@code clazz}.
     * @param jdbcTemplate JDBC operation template
     * @param sql sql
     * @param params parameter values bound positionally to the sql placeholders
     * @param size rows per page; must be positive
     * @param clazz result type
     */
    @SuppressWarnings("unchecked")
    public PageSqlResultSetSpliterator(JdbcTemplate jdbcTemplate, String sql, Object[] params, int size, Class<T> clazz) {
        this(jdbcTemplate, sql, params, size, rs -> (T) JdbcUtils.getResultSetValue(rs, 1, clazz));
    }

    /**
     * Emits the next row, fetching the next page when the current one is exhausted.
     * <p>
     * Once no rows remain, resources are released and this and every later call return {@code false}.
     *
     * @param action consumer receiving the mapped row
     * @return {@code true} if a row was emitted
     * @throws DataAccessException if a JDBC call fails
     */
    @Override
    public boolean tryAdvance(Consumer<? super T> action) {
        // without this guard a second call after exhaustion would call next() on a closed ResultSet
        if (finished) {
            return false;
        }
        if (total <= 0) {
            // the constructor already opened the first page; release it instead of leaking it
            release();
            return false;
        }
        try {
            if (resultSet.next()) {
                action.accept(callback.callback(resultSet));
                return true;
            }
            if (currentPage.longValue() < pages) {
                currentPage.incrementAndGet();

                resultSet = getResultSet();
                if (resultSet.next()) {
                    action.accept(callback.callback(resultSet));
                    return true;
                }
            }
        } catch (SQLException e) {
            throw new DataAccessException("JDBC SQL Execute Fail.", e);
        }
        release();
        return false;
    }

    /**
     * Splitting is not supported.
     * @return always {@code null}
     */
    @Override
    public Spliterator<T> trySplit() {
        return null;
    }

    /**
     * Size is reported as unknown.
     * @return {@link Long#MAX_VALUE}
     */
    @Override
    public long estimateSize() {
        return Long.MAX_VALUE;
    }

    @Override
    public int characteristics() {
        return Spliterator.ORDERED;
    }

    /**
     * Creates a sequential Stream over this spliterator whose close handler releases the result set,
     * statements and connection.
     * @return the stream; should be closed by the caller
     */
    public Stream<T> stream() {
        return StreamSupport.stream(this, false).onClose(() -> {
            finished = true;
            JdbcUtils.closeResultSet(resultSet);
            JdbcUtils.closeStatement(statement);
            JdbcUtils.closeStatement(preparedStatement);
            DataSource dataSource;
            if ((dataSource = jdbcTemplate.getDataSource()) instanceof WrapBasicDataSource) {
                JdbcUtils.closeConnection(connection);
            } else if (dataSource instanceof WrapCacheDataSource) {
                // NOTE(review): presumably isActive() means a transaction is bound to the template,
                // in which case the connection must stay with it — confirm
                if (!jdbcTemplate.isActive()) {
                    ((WrapCacheDataSource) dataSource).free(connection);
                }
            } else {
                JdbcUtils.closeConnection(connection);
            }
        });
    }

    /**
     * Sets the row-mapping callback.
     * @param callback row-mapping callback
     */
    protected void setResultSetRowCallback(ResultSetRowCallback<T> callback) {
        this.callback = callback;
    }

    /**
     * Builds the row-count sql by wrapping the query as a derived table.
     * Called from the constructor.
     * @param sql sql
     * @return row-count sql
     */
    protected String getTotalCountSql(String sql) {
        return "SELECT COUNT(1) FROM ("+ sql +") AS T";
    }

    /**
     * Builds the paging sql by appending a {@code LIMIT offset, size} clause (MySQL-style syntax);
     * override for databases with a different paging syntax. Called from the constructor.
     * @param sql sql
     * @param currentPage current page (1-based)
     * @param size rows per page
     * @return paging sql
     */
    protected String getPageOffsetSql(String sql, long currentPage, long size) {
        return sql + " LIMIT " + ((currentPage - 1) * size) + ", " + size;
    }

    /**
     * Rejects non-positive page sizes.
     * @param size rows per page
     */
    private static void requirePositiveSize(int size) {
        if (size <= 0) {
            throw new IllegalArgumentException("size must be positive: " + size);
        }
    }

    /**
     * Computes the number of pages needed for {@code total} rows (ceiling division).
     * @param total total rows
     * @param size rows per page (positive)
     * @return page count
     */
    private static long pageCount(long total, int size) {
        return total % size == 0 ? total / size : total / size + 1;
    }

    /**
     * Captures a connection from the template for use by all page queries.
     * NOTE(review): the connection handed to the callback is kept and used after execute(...) returns;
     * this relies on the template not closing it at that point — confirm.
     */
    private void acquireConnection() {
        jdbcTemplate.execute((ConnectionCallback<Object>) connection -> {
            setConnection(connection);
            return null;
        });
    }

    /**
     * Marks the spliterator finished and closes the result set and statements. Also closes the connection
     * unless the data source is a {@code WrapCacheDataSource}, which is released by the stream close handler.
     * @throws DataAccessException if a JDBC call fails
     */
    private void release() {
        // set first so a failure while closing cannot lead to calls on closed resources later
        finished = true;
        try {
            if (resultSet != null && !resultSet.isClosed()) {
                resultSet.close();
            }
            if (preparedStatement != null && !preparedStatement.isClosed()) {
                preparedStatement.close();
            }
            if (statement != null && !statement.isClosed()) {
                statement.close();
            }
            if (!(jdbcTemplate.getDataSource() instanceof WrapCacheDataSource)) {
                JdbcUtils.closeConnection(connection);
            }
        } catch (SQLException e) {
            throw new DataAccessException("JDBC SQL Execute Fail.", e);
        }
    }

    /**
     * Sets the connection.
     * @param connection connection
     */
    private void setConnection(Connection connection) {
        this.connection = connection;
    }

    /**
     * Executes the query for {@link #currentPage}, first closing the previous page's result set and
     * statement. Uses a plain Statement when {@link #params} is null, otherwise a PreparedStatement with
     * the params bound positionally.
     * @return result set of the current page
     * @throws DataAccessException if a JDBC call fails
     */
    private ResultSet getResultSet() {
        try {
            String pageQuerySql = getPageOffsetSql(sql, currentPage.longValue(), size);
            if (resultSet != null && !resultSet.isClosed()) {
                resultSet.close();
            }
            if (params == null) {
                if (statement != null && !statement.isClosed()) {
                    statement.close();
                }
                statement = connection.createStatement();
                return statement.executeQuery(pageQuerySql);
            } else {
                if (preparedStatement != null && !preparedStatement.isClosed()) {
                    preparedStatement.close();
                }
                preparedStatement = connection.prepareStatement(pageQuerySql);
                for (int i = 0; i < params.length; i++) {
                    preparedStatement.setObject(i + 1, params[i]);
                }
                return preparedStatement.executeQuery();
            }
        } catch (SQLException e) {
            throw new DataAccessException("JDBC SQL Execute Fail.", e);
        }
    }
}
